import torch
def f(W, Y, x):
    r"""Batched least-squares objective :math:`f(\theta) = \|W\theta - y\|_2^2`.

    For each batch element, compute the squared Euclidean norm of the residual
    ``W @ x - Y`` (summed over dim 1), then average over the batch (dim 0).

    Args:
        W: Linear operators; assumed shape ``(B, m, d)``, or ``(m, d)`` shared
            across the batch via ``torch.matmul`` broadcasting -- confirm
            against callers.
        Y: Targets; assumed shape ``(B, m)``.
        x: Parameters :math:`\theta`; assumed shape ``(B, d)``.

    Returns:
        Tensor: mean over the batch of the per-sample squared residual norms
        (a 0-d tensor when the residual is 2-D).
    """
    # Treat each x as a column vector so matmul performs a (batched)
    # matrix-vector product: (..., m, d) @ (..., d, 1) -> (..., m, 1).
    pred = torch.matmul(W, x.unsqueeze(-1)).squeeze(-1)
    # Only the trailing axis added above is squeezed: a bare squeeze() would
    # also drop an m == 1 (or B == 1) axis, and the subtraction below would
    # then broadcast against Y into a wrong (B, B) shape.
    residual = pred - Y
    return (residual ** 2).sum(dim=1).mean(dim=0)